{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Common Metrics in TensorFlow"
   ]
  },
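  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "***What they are:***  \n",
    "A metric is a function that evaluates the performance of the model being trained.\n",
    "\n",
    "***How metrics resemble loss functions:***  \n",
    "A metric function has the same signature as a loss function: `func(y_true, y_pred)`.\n",
    "\n",
    "***How metrics differ from loss functions:***  \n",
    "1. The result of a metric is not used to train the model.\n",
    "2. Metric objects are stateful.\n",
    "\n",
    "Unlike a loss, a metric object accumulates data across training batches and computes its value from the accumulated state.  \n",
    "For example, Keras's `Accuracy` maintains a `total` variable holding the number of predictions that matched the labels and a `count` variable holding the number of samples seen so far.  \n",
    "Formula: Accuracy = total / count\n",
    "\n",
    "* `update_state` updates the state (here, `total` and `count`).\n",
    "* `result` returns the current metric value (here, `total / count` over everything seen so far).\n",
    "* `reset_states` clears the state (`total = count = 0`); newer TensorFlow releases name this method `reset_state`."
   ]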
  },
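  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the stateful behaviour concrete, the cell below is a minimal sketch that feeds two batches into a single `tf.keras.metrics.Accuracy` object and then resets it. The batch values are made up for illustration, and the cell imports TensorFlow itself so that it can run on its own."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "m = tf.keras.metrics.Accuracy()\n",
    "\n",
    "# Batch 1 (made-up data): 3 of 4 predictions equal their labels\n",
    "m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]])\n",
    "print('after batch 1:', float(m.result()))\n",
    "\n",
    "# Batch 2 (made-up data): 1 of 2 predictions equals its label.\n",
    "# The result now covers both batches: 4 correct out of 6 seen.\n",
    "m.update_state([[5], [6]], [[5], [0]])\n",
    "print('after batch 2:', float(m.result()))\n",
    "\n",
    "# Clear the accumulated state. Depending on the TensorFlow version the\n",
    "# method is named reset_states or reset_state, so use whichever exists.\n",
    "reset = getattr(m, 'reset_state', None) or m.reset_states\n",
    "reset()\n",
    "print('after reset:', float(m.result()))"
   ]
  },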
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Using Metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A metric can be specified by the name of a built-in metric, by a metric object, or by a custom TensorFlow function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = tf.keras.Sequential([\n",
    "    tf.keras.layers.Dense(32, activation='relu'),\n",
    "    tf.keras.layers.Dense(1, activation='sigmoid')\n",
    "])"
   ]
  },
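  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Method 1: pass metric names as strings\n",
    "# ('AUC' is spelled like the class name; some TensorFlow versions reject lowercase 'auc')\n",
    "model.compile(loss='mean_squared_error',\n",
    "              optimizer='sgd',\n",
    "              metrics=['accuracy', 'AUC'])"
   ]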
  },
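  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Method 2: pass a list of metric objects\n",
    "# BinaryAccuracy thresholds the sigmoid output; plain Accuracy() would require\n",
    "# predictions to equal the labels exactly\n",
    "model.compile(loss='mean_squared_error',\n",
    "              optimizer='sgd',\n",
    "              metrics=[tf.keras.metrics.BinaryAccuracy(), tf.keras.metrics.AUC()])"
   ]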
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Writing Your Own Metric Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def my_metric_fn(y_true, y_pred):\n",
    "    \"\"\"Custom metric: mean squared error over the last axis.\"\"\"\n",
    "    return tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=float64, numpy=0.1875>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Quick test: squared errors are 0, 0.25, 0.49 and 0.01, so the mean is 0.1875\n",
    "my_metric_fn(np.array([0, 0, 1, 1]), np.array([0, 0.5, 0.3, 0.9]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Usage: pass the function itself in the metrics list\n",
    "model.compile(optimizer='adam', loss='mean_squared_error', metrics=[my_metric_fn])"
   ]
  },
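  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A plain function such as `my_metric_fn` keeps no state of its own. When a metric needs to accumulate its own state across batches, one option is to subclass `tf.keras.metrics.Metric` and implement `update_state` and `result`. The cell below is only a sketch under stated assumptions: the class name `ExactMatchRate` and the variable names `num_correct` and `num_seen` are hypothetical and chosen for illustration, and `sample_weight` is accepted but ignored."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "class ExactMatchRate(tf.keras.metrics.Metric):\n",
    "    \"\"\"Hypothetical stateful metric: fraction of predictions equal to the labels.\"\"\"\n",
    "\n",
    "    def __init__(self, name='exact_match_rate', **kwargs):\n",
    "        super().__init__(name=name, **kwargs)\n",
    "        # State variables that persist across calls to update_state\n",
    "        self.num_correct = self.add_weight(name='num_correct', initializer='zeros')\n",
    "        self.num_seen = self.add_weight(name='num_seen', initializer='zeros')\n",
    "\n",
    "    def update_state(self, y_true, y_pred, sample_weight=None):\n",
    "        # sample_weight is ignored in this sketch\n",
    "        y_true = tf.cast(y_true, tf.float32)\n",
    "        y_pred = tf.cast(y_pred, tf.float32)\n",
    "        matches = tf.cast(tf.equal(y_true, y_pred), tf.float32)\n",
    "        self.num_correct.assign_add(tf.reduce_sum(matches))\n",
    "        self.num_seen.assign_add(tf.cast(tf.size(matches), tf.float32))\n",
    "\n",
    "    def result(self):\n",
    "        # divide_no_nan returns 0 when nothing has been seen yet\n",
    "        return tf.math.divide_no_nan(self.num_correct, self.num_seen)\n",
    "\n",
    "\n",
    "m = ExactMatchRate()\n",
    "m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]])\n",
    "m.result().numpy()"
   ]
  },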
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Common Classification Metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.75"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m = tf.keras.metrics.Accuracy()\n",
    "# 3 of the 4 predictions equal their labels\n",
    "m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]])\n",
    "m.result().numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### AUC (Area Under the Curve)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.75"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m = tf.keras.metrics.AUC()\n",
    "m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])\n",
    "m.result().numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Precision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6666667"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m = tf.keras.metrics.Precision()\n",
    "# True positives = 2, false positives = 1, so precision = 2 / 3\n",
    "m.update_state([0, 1, 1, 1], [1, 0, 1, 1])\n",
    "m.result().numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Recall"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6666667"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m = tf.keras.metrics.Recall()\n",
    "# True positives = 2, false negatives = 1, so recall = 2 / 3\n",
    "m.update_state([0, 1, 1, 1], [1, 0, 1, 1])\n",
    "m.result().numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Common Regression Metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### MSE (Mean Squared Error)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.25"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m = tf.keras.metrics.MeanSquaredError()\n",
    "# Squared errors are [[1, 0], [0, 0]], so the mean is 1 / 4\n",
    "m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])\n",
    "m.result().numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### RMSE (Root Mean Squared Error)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m = tf.keras.metrics.RootMeanSquaredError()\n",
    "# Square root of the mean squared error: sqrt(0.25) = 0.5\n",
    "m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])\n",
    "m.result().numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### MAE (Mean Absolute Error)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.25"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m = tf.keras.metrics.MeanAbsoluteError()\n",
    "# Absolute errors are [[1, 0], [0, 0]], so the mean is 1 / 4\n",
    "m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])\n",
    "m.result().numpy()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
